# This is the experimental code for the event understanding and intent inference tasks with the VisualCOMET model.
# !/usr/bin/env python3
# coding=utf-8
# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Conditional text generation with the auto-regressive models of the library (GPT/GPT-2/Transformer-XL/XLNet)
"""
from __future__ import absolute_import, division, print_function, unicode_literals

import os
import json
import logging
import argparse
import pickle
import sys

import numpy as np
import torch
import torch.nn.functional as F
import tqdm
from torch.utils.data import SequentialSampler
from torch.utils.data import DataLoader

from transformers import GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig

import nltk
from nltk.stem.porter import PorterStemmer
from nltk.corpus import wordnet
from itertools import chain, product

from dataloaders.tokenizers import VisualCometTokenizer
from dataloaders.vcg_generation import VCGGenDataset
from models.model import GPT2VisionAttentiveLMHead
from utils.file_utils import replace_names

# Process-wide logging: timestamped records at INFO level and above.
logging.basicConfig(
    level=logging.INFO,
    datefmt='%m/%d/%Y %H:%M:%S',
    format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
)
logger = logging.getLogger(__name__)

# Hardcoded max length to avoid infinite loop
MAX_LENGTH = 10000

# Every key of ``pretrained_config_archive_map`` across the listed config
# classes, flattened into a single tuple (order follows the class order below).
ALL_MODELS = tuple(
    archive_key
    for config_cls in (GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig)
    for archive_key in config_cls.pretrained_config_archive_map.keys()
)

# Maps a model-type string to its (config class, model class, tokenizer class).
MODEL_CLASSES = {
    'gpt2_vc': (GPT2Config, GPT2VisionAttentiveLMHead, VisualCometTokenizer),
}

# Vocabulary ids hard-coded by this script.
# NOTE(review): presumably the ids of the event begin/end markers and of the
# padding token in the VisualCOMET vocabulary -- confirm against the tokenizer.
_EVENT_BEGIN_ID = 50259
_EVENT_END_ID = 50260
_PAD_ID = 50256


def _substitute_person_names(event, person2name):
    """Return ``event`` with each whitespace-separated word that is a key of
    ``person2name`` replaced by its mapped value; words are re-joined with
    single spaces."""
    return " ".join(person2name.get(word, word) for word in event.split())


def _splice_event_tokens(row, token):
    """Overwrite, in place, the span between the first two event markers of ``row``.

    ``row`` is a 1-D tensor of token ids and ``token`` a list of ids. The
    positions of the first two entries equal to ``_EVENT_BEGIN_ID`` or
    ``_EVENT_END_ID`` delimit the span that is replaced. The length of ``row``
    never changes:

    * shorter replacement: the tokens after the old end marker are shifted
      left and the vacated slots at the end of the row are set to ``_PAD_ID``;
    * equal length: only the span itself is overwritten;
    * longer replacement: the tokens after the old end marker are shifted
      right and whatever no longer fits is dropped.

    Raises:
        IndexError: if fewer than two markers are present, or if the
            replacement does not fit inside ``row``.
    """
    snapshot = row.detach().clone()  # source for the shifted tail, immune to the writes below
    marker_positions = [pos for pos, tok in enumerate(row.tolist())
                        if tok in (_EVENT_BEGIN_ID, _EVENT_END_ID)]
    begin_pos, end_pos = marker_positions[0], marker_positions[1]

    old_len = end_pos - begin_pos - 1
    new_len = len(token)
    rep_start = begin_pos + 1
    rep_end = begin_pos + new_len  # index of the last replacement token
    row_len = len(row)

    if new_len < old_len:
        tail_len = row_len - (end_pos + 1)
        for offset, tok in enumerate(token):
            row[rep_start + offset] = tok
        row[rep_end + 1] = _EVENT_END_ID
        if tail_len > 0:
            row[rep_end + 2:rep_end + 2 + tail_len] = snapshot[end_pos + 1:]
        # Pad the slots vacated by the shift. The previous code wrote the
        # padding at index ``tail_len + k``, i.e. it used a token *count* as a
        # position, which overwrote unrelated tokens and left these slots stale.
        row[rep_end + 2 + tail_len:] = _PAD_ID
    elif new_len == old_len:
        for offset, tok in enumerate(token):
            row[rep_start + offset] = tok
    else:
        tail_len = row_len - (rep_end + 2)
        if tail_len > 0:
            row[rep_end + 2:] = snapshot[end_pos + 1:end_pos + 1 + tail_len]
        for offset, tok in enumerate(token):
            row[rep_start + offset] = tok
        row[rep_end + 1] = _EVENT_END_ID


def main():
    """Generate text for every record of ``args.split`` and dump the results to JSON.

    For each (record, batch) pair the loop:

    1. overwrites the first begin-inference token found in the input ids with
       the id of ``<|b_ev|>``;
    2. for every entry of ``det_false_no_conf`` with the same ``img_fn``,
       stores that entry's event as ``gt_event`` and sets the record's
       ``event`` to the generation with the highest METEOR score against it
       (ties go to the later generation);
    3. substitutes words of the event through ``person2name``, stores the
       result as ``event_name`` and splices its token ids between the event
       markers of the input ids;
    4. zeroes the visual inputs of object slots 1..14 whose name, read from
       the record's metadata JSON, is in ``confounder``;
    5. samples generations, maps them through ``replace_names`` and collects
       the output record.

    The output path is built from ``args.output_file`` (or
    ``args.model_name_or_path + '/'``) plus the split and sampling settings.

    NOTE(review): ``args``, ``config``, ``tokenizer``, ``model``,
    ``det_false_no_conf``, ``confounder``, ``set_seed``, ``sample_sequence``
    and ``load_and_cache_vcg_examples`` are not defined in this function and
    must be available at module scope.
    NOTE(review): only index 0 of every batch is edited and records are zipped
    one-to-one with batches, so this appears to assume
    ``args.gen_batch_size == 1`` -- confirm.
    """
    def _prompt_to_gen(context_orig, img_feats, boxes, boxes_mask, objects, segments, person_ids, subject_ids):
        """Sample ``args.num_samples`` continuations of ``context_orig`` and decode them.

        ``context_orig`` is indexed as a 2-D tensor of token ids. It must
        contain exactly one of the tokenizer's begin-inference ids; the prompt
        is the prefix up to and including the first occurrence of that id.
        The remaining arguments are forwarded unchanged to ``sample_sequence``.

        Returns:
            A list with one decoded string per sampled sequence, each cut at
            its first end-of-inference token.
        """
        set_seed(args)

        if args.model_type in ["transfo-xl", "xlnet"]:
            # Models with memory like to have a long prompt for short inputs.
            raise Exception("Not supported")

        context = context_orig.repeat(args.num_samples, 1)

        # set the start token to signal when to start generation
        pad_token = tokenizer.convert_tokens_to_ids([tokenizer.unk_token])[0]
        possible_inferences = [tokenizer.convert_tokens_to_ids([br])[0] for br in tokenizer.begin_inferences.values()]
        begin_inference = [r for r in possible_inferences if r in context_orig]
        assert len(begin_inference) == 1
        prompt_token_idx = begin_inference[0]
        end_token = tokenizer.convert_tokens_to_ids([tokenizer.end_inference])[0]

        # cut off everything after the first occurrence of the prompt token
        idx_of_prompt_token = (context_orig == prompt_token_idx).nonzero()[0][1].item()
        context[:, idx_of_prompt_token] = prompt_token_idx
        context_input = context[:, :idx_of_prompt_token + 1]

        # begin sampling sequence starting from context_input
        out = sample_sequence(
            model=model,
            context=context_input,
            pad_token=pad_token,
            end_token=end_token,
            img_feats=img_feats,
            boxes=boxes,
            boxes_mask=boxes_mask,
            objects=objects,
            segments=segments,
            length=args.length,
            temperature=args.temperature,
            top_k=args.top_k,
            top_p=args.top_p,
            do_sample=args.do_sample,
            device=args.device,
            num_samples=args.num_samples,
            person_ids=person_ids,
            subject_ids=subject_ids
        )

        # ensure to end the sequence with end token, and pad the rest of the sequence.
        out = out[:, idx_of_prompt_token + 1:]
        out[:, -1] = end_token
        processed = set()
        for sample_idx, end_idx in (out == end_token).nonzero().tolist():
            # Only the first end token of each sample matters.
            if sample_idx in processed:
                continue
            processed.add(sample_idx)
            if end_idx < out.size(1) - 1:
                out[sample_idx, end_idx + 1:] = pad_token
            # Mirror the generated tokens into ``context``. ``context`` is not
            # returned, so this does not influence the decoded text.
            context_end_idx = idx_of_prompt_token + 1 + end_idx
            context[sample_idx, idx_of_prompt_token + 1:context_end_idx + 1] = out[sample_idx, :end_idx + 1]
            if context_end_idx < context.size(1) - 1:
                # Pad this row only (the slice used to start at ``sample_idx:``
                # and therefore touched every following row as well).
                context[sample_idx, context_end_idx + 1:] = pad_token

        # decode the sequence to text
        return [tokenizer.decode(o, skip_special_tokens=True, clean_up_tokenization_spaces=True) for o in out.tolist()]

    output_name = args.output_file if args.output_file is not None else args.model_name_or_path + '/'
    output_file = '{}{}_sample_{}_num_{}_top_k_{}_top_p_{}.json'.format(
        output_name,
        args.split,
        args.do_sample,
        args.num_samples,
        args.top_k,
        args.top_p,
        )
    print(output_file)

    # Get Dataset Loader
    dataset = load_and_cache_vcg_examples(args, config, tokenizer)
    eval_dataset = dataset.get_dataset(args.split)
    all_records = eval_dataset.records
    eval_dataloader = DataLoader(eval_dataset, sampler=SequentialSampler(eval_dataset),
                                 batch_size=args.gen_batch_size)

    # event understanding: register '<|b_ev|>' as an additional begin-inference token
    bev_index = tokenizer.convert_tokens_to_ids(['<|b_ev|>'])[0]
    tokenizer.begin_inferences['event_before'] = '<|b_ev|>'
    possible_inferences = [tokenizer.convert_tokens_to_ids([br])[0] for br in tokenizer.begin_inferences.values()]

    output_keys = ['img_fn', 'movie', 'metadata_fn', 'split', 'place', 'event', 'inference_relation', 'event_idx', 'gt_event']
    results = []
    for input_record, data_input in tqdm.tqdm(zip(all_records, eval_dataloader), total=len(all_records)):

        if args.include_image:
            inputs, img_feats, boxes, boxes_mask, objects, segments, person_ids, subject_ids = \
                [d.to(args.device) for d in data_input]
        else:
            inputs = data_input.to(args.device)
            img_feats, boxes, boxes_mask, objects, segments, person_ids, subject_ids = [None] * 7

        # event understanding: swap the first begin-inference token for '<|b_ev|>'
        begin_inference = [r for r in possible_inferences if r in inputs]
        begin_inference_index = inputs[0].tolist().index(begin_inference[0])
        inputs[0][begin_inference_index] = bev_index

        # intent inference: adopt the generation closest (by METEOR) to the reference event
        for candidate in det_false_no_conf:
            if input_record['img_fn'] != candidate['img_fn']:
                continue
            gt_event = candidate['event']
            input_record['gt_event'] = gt_event
            best_score = 0.0
            for gen_event in candidate['generations']:
                cur_score = nltk.translate.meteor_score.single_meteor_score(gt_event, gen_event)
                if cur_score >= best_score:
                    input_record['event'] = gen_event
                    best_score = cur_score

        # Replace person tags by names and splice the re-encoded event into the input ids.
        event_with_names = _substitute_person_names(input_record['event'], input_record['person2name'])
        input_record['event_name'] = event_with_names
        token = tokenizer.encode(event_with_names)
        _splice_event_tokens(inputs[0], token)

        # remove confounder
        metadata_path = os.path.join('/workspace/vcr1images', input_record['metadata_fn'])
        with open(metadata_path, 'r') as fp:
            metadata = json.load(fp)
        # Only object slots 1..14 are ever cleared, i.e. the first 14 names.
        # NOTE(review): the visual tensors are None when args.include_image is
        # false, so a confounder match would raise in that mode.
        for object_i, obj_name in enumerate(metadata['names'][:14], start=1):
            if obj_name in confounder:
                img_feats[0, object_i, :] = 0
                boxes[0, object_i, :] = 0
                boxes_mask[0, object_i] = 0
                objects[0, object_i] = 0
                segments[0, object_i, :, :] = 0
                subject_ids[0, object_i] = 0

        # Now, generate the inferences and decode using original ids.
        generations = _prompt_to_gen(inputs, img_feats, boxes, boxes_mask, objects, segments, person_ids, subject_ids)
        generations = [replace_names(generation, input_record['name2person']) for generation in generations]
        output_record = {k: input_record[k] for k in output_keys}
        output_record['generations'] = generations
        results.append(output_record)

    # Close the file deterministically instead of leaking the handle.
    with open(output_file, 'w') as fp:
        json.dump(results, fp)
    print('Saved to', output_file)

# Run the generation pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()

